import os
import time

import torch
from diffusers import DDIMScheduler
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from PIL import ImageChops

from pipeline_sd_inpaint import SDInpaintPipeline

# Load the local SD 2.1-base checkpoint as an inpainting pipeline in full
# precision and move it onto the GPU in the same expression.
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "../models/stable-diffusion-2-1-base", torch_dtype=torch.float32
).to("cuda")

# Swap the default scheduler for DDIM, built from the same scheduler config.
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

# One seeded CUDA generator shared by every inference call below.
generator = torch.Generator(device="cuda").manual_seed(42)



# Benchmark prompts, one per line of text.txt. The encoding is explicit so the
# prompts decode the same way regardless of the platform's locale default.
with open("Benchmark/text.txt", "r", encoding="utf-8") as f:
    prompts_data = f.read().splitlines()

# Face images and their background masks, paired by index 0..199.
face_image_paths = [f"Benchmark/CelebA-HQ-img/{i}.jpg" for i in range(200)]
mask_image_paths = [f"Benchmark/CelebA-HQ-bgmask/{i}.jpg" for i in range(200)]

# Benchmark: inpaint `num_images` faces and accumulate the time spent inside
# the pipeline call only (image preparation and saving are not timed).
# Each face is resized to seed_size x seed_size and pasted at the top-centre
# of a grey canvas_size x canvas_size canvas. The mask canvas is white except
# where the resized mask is pasted.
num_images = 200
seed_size = 256     # side length the face and mask are resized to
canvas_size = 512   # side length of the image/mask passed to the pipeline
grey = (128, 128, 128)

# Centre horizontally, flush with the top edge.
x_offset = (canvas_size - seed_size) // 2
y_offset = 0

# image.save() fails if the output directory is missing.
os.makedirs("./test", exist_ok=True)

total_time = 0.0
for idx in range(num_images):
    # `with` closes the underlying files; resize() returns an independent copy.
    # convert("RGB") guarantees the mode matches the RGB grey fill used in
    # the composite below (a no-op copy for ordinary RGB JPEGs).
    with Image.open(face_image_paths[idx]) as face_file:
        face_image = face_file.convert("RGB").resize((seed_size, seed_size))
    with Image.open(mask_image_paths[idx]) as mask_file:
        mask_image = mask_file.resize((seed_size, seed_size))

    # Hard-threshold the mask to 1-bit. convert("1") would apply
    # Floyd-Steinberg dithering by default, which scatters speckles across
    # the JPEG-compressed mask instead of producing a clean cut.
    mask_binary = mask_image.convert("L").point(
        lambda p: 255 if p >= 128 else 0, mode="1"
    )
    keep_region = ImageChops.invert(mask_binary)
    # Keep face pixels where the mask is black; grey out where it is white.
    masked_face = ImageChops.composite(
        face_image, Image.new("RGB", face_image.size, grey), keep_region
    )

    canvas_image = Image.new("RGB", (canvas_size, canvas_size), color=grey)
    canvas_mask = Image.new("RGB", (canvas_size, canvas_size), color=(255, 255, 255))
    canvas_image.paste(masked_face, (x_offset, y_offset))
    canvas_mask.paste(mask_image, (x_offset, y_offset))

    # NOTE(review): idx // 2 assigns each prompt line to two consecutive
    # images — confirm this matches the layout of Benchmark/text.txt.
    prompt = prompts_data[idx // 2]

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    result = pipeline(
        prompt=prompt,
        image=canvas_image,
        mask_image=canvas_mask,
        generator=generator,
    ).images[0]
    total_time += time.perf_counter() - start_time

    result.save(f"./test/{idx}.jpg")

# NOTE(review): the first timed call likely includes one-off GPU warm-up
# cost, inflating the average; consider an untimed warm-up call.
average_time_per_image = total_time / num_images

print(f"Total time for {num_images} images: {total_time} seconds")
print(f"Average time per image: {average_time_per_image} seconds")